import torch
import numpy as np

"""
自定义评价指标
"""


# MCR_MSE evaluation: mean column-wise root mean squared error (MCRMSE)
def MCR_MSE(scores, targets, mask=None):
    """Compute the mean column-wise root mean squared error (MCRMSE).

    For every column (the last dimension) the RMSE between ``scores`` and
    ``targets`` is computed over all selected rows; the per-column RMSEs are
    then averaged into a single scalar.

    Args:
        scores: Predicted values. Cast to float32; the last dimension is
            treated as the column dimension.
        targets: Ground-truth values with the same number of columns as
            ``scores``. Cast to float32.
        mask: Optional tensor broadcastable to ``scores`` / ``targets``. Only
            positions where it is true are evaluated. Non-bool masks (uint8,
            0/1 integers) are converted with ``.to(torch.bool)``. The mask
            must keep or drop whole rows (e.g. shape ``(..., 1)``). Otherwise
            the selected elements cannot be regrouped into
            ``(-1, n_columns)``, and ``view`` either raises or mixes columns.
            ``None`` (default) evaluates every row.

    Returns:
        A 0-dim float32 tensor holding the MCRMSE. When no rows are selected,
        the mean over zero rows makes the result NaN.
    """
    preds = scores.to(torch.float32)
    truth = targets.to(torch.float32)

    if mask is None:
        # No mask: flatten all leading dimensions into rows.
        truth = truth.reshape(-1, truth.size(-1))
        preds = preds.reshape(-1, preds.size(-1))
    else:
        # masked_select requires a bool mask; accept uint8/int masks too.
        keep = mask.to(torch.bool)
        # Evaluate only the unmasked positions, then restore the column layout.
        truth = torch.masked_select(truth, keep).view(-1, truth.size(-1))
        preds = torch.masked_select(preds, keep).view(-1, preds.size(-1))

    # RMSE per column (reduce over rows, dim=0), then average across columns.
    col_rmse = torch.sqrt(torch.mean(torch.square(preds - truth), dim=0))
    return torch.mean(col_rmse)
